{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Create PyTorchJob using Kubeflow Training SDK"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This notebook is a sample for the Kubeflow Training SDK (`kubeflow-training`).\n",
    "\n",
    "It shows how to use the SDK to create a PyTorchJob, get it, check its conditions, wait for it to finish, verify that it succeeded, read its training logs, and delete it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Install Kubeflow Training Python SDKs\n",
    "\n",
    "You need to install the Kubeflow Training SDK to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO (andreyvelich): Change to release version when SDK with the new APIs is published.\n",
    "!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from kubernetes.client import V1PodTemplateSpec\n",
    "from kubernetes.client import V1ObjectMeta\n",
    "from kubernetes.client import V1PodSpec\n",
    "from kubernetes.client import V1Container\n",
    "\n",
    "from kubeflow.training import KubeflowOrgV1ReplicaSpec\n",
    "from kubeflow.training import KubeflowOrgV1PyTorchJob\n",
    "from kubeflow.training import KubeflowOrgV1PyTorchJobSpec\n",
    "from kubeflow.training import KubeflowOrgV1RunPolicy\n",
    "from kubeflow.training import TrainingClient\n",
    "\n",
    "from kubeflow.training import constants"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Define PyTorchJob"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This demo creates a PyTorchJob with one Master replica and one Worker replica to run the MNIST sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "name = \"pytorch-dist-mnist-gloo\"\n",
    "namespace = \"kubeflow-user-example-com\"\n",
    "container_name = \"pytorch\"\n",
    "\n",
    "container = V1Container(\n",
    "    name=container_name,\n",
    "    image=\"gcr.io/kubeflow-ci/pytorch-dist-mnist-test:v1.0\",\n",
    "    args=[\"--backend\", \"gloo\"],\n",
    ")\n",
    "\n",
    "replica_spec = KubeflowOrgV1ReplicaSpec(\n",
    "    replicas=1,\n",
    "    restart_policy=\"OnFailure\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        metadata=V1ObjectMeta(\n",
    "            name=name,\n",
    "            namespace=namespace,\n",
    "            annotations={\n",
    "                \"sidecar.istio.io/inject\": \"false\"\n",
    "            }\n",
    "        ),\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "pytorchjob = KubeflowOrgV1PyTorchJob(\n",
    "    api_version=constants.API_VERSION,\n",
    "    kind=constants.PYTORCHJOB_KIND,\n",
    "    metadata=V1ObjectMeta(name=name, namespace=namespace),\n",
    "    spec=KubeflowOrgV1PyTorchJobSpec(\n",
    "        run_policy=KubeflowOrgV1RunPolicy(clean_pod_policy=\"None\"),\n",
    "        pytorch_replica_specs={\n",
    "            \"Master\": replica_spec,\n",
    "            \"Worker\": replica_spec\n",
    "        },\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Create PyTorchJob\n",
    "\n",
    "You have to create a Training Client to deploy your PyTorchJob in your cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTorchJob kubeflow-user-example-com/pytorch-dist-mnist-gloo has been created\n"
     ]
    }
   ],
   "source": [
    "# Namespace will be reused in every API call.\n",
    "training_client = TrainingClient(namespace=namespace)\n",
    "\n",
    "# If `job_kind` is not set in `TrainingClient`, we need to set it for each API.\n",
    "training_client.create_job(pytorchjob, job_kind=constants.PYTORCHJOB_KIND)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Get the Created PyTorchJob\n",
    "\n",
    "You can verify the name of the created PyTorchJob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'pytorch-dist-mnist-gloo'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_client.get_job(name, job_kind=constants.PYTORCHJOB_KIND).metadata.name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Get the PyTorchJob Conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'last_transition_time': datetime.datetime(2023, 9, 8, 21, 14, 59, tzinfo=tzutc()),\n",
       "  'last_update_time': datetime.datetime(2023, 9, 8, 21, 14, 59, tzinfo=tzutc()),\n",
       "  'message': 'PyTorchJob pytorch-dist-mnist-gloo is created.',\n",
       "  'reason': 'PyTorchJobCreated',\n",
       "  'status': 'True',\n",
       "  'type': 'Created'},\n",
       " {'last_transition_time': datetime.datetime(2023, 9, 8, 21, 15, 45, tzinfo=tzutc()),\n",
       "  'last_update_time': datetime.datetime(2023, 9, 8, 21, 15, 45, tzinfo=tzutc()),\n",
       "  'message': 'PyTorchJob pytorch-dist-mnist-gloo is running.',\n",
       "  'reason': 'JobRunning',\n",
       "  'status': 'True',\n",
       "  'type': 'Running'}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_client.get_job_conditions(name=name, job_kind=constants.PYTORCHJOB_KIND)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Wait Until PyTorchJob Finishes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "NAME                           STATE                TIME\n",
      "pytorch-dist-mnist-gloo        Running              2023-09-08 21:15:45+00:00\n",
      "pytorch-dist-mnist-gloo        Running              2023-09-08 21:15:45+00:00\n",
      "pytorch-dist-mnist-gloo        Succeeded            2023-09-08 21:26:44+00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Succeeded number of replicas: 1\n"
     ]
    }
   ],
   "source": [
    "pytorchjob = training_client.wait_for_job_conditions(name=name, job_kind=constants.PYTORCHJOB_KIND)\n",
    "\n",
    "print(f\"Succeeded number of replicas: {pytorchjob.status.replica_statuses['Master'].succeeded}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Verify That the PyTorchJob Has Succeeded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_client.is_job_succeeded(name=name, job_kind=constants.PYTORCHJOB_KIND)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Get the PyTorchJob Training Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The logs of pod pytorch-dist-mnist-gloo-master-0:\n",
      " Using distributed PyTorch with gloo backend\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Processing...\n",
      "Done!\n",
      "Train Epoch: 1 [0/60000 (0%)]\tloss=2.3000\n",
      "Train Epoch: 1 [640/60000 (1%)]\tloss=2.2135\n",
      "Train Epoch: 1 [1280/60000 (2%)]\tloss=2.1704\n",
      "Train Epoch: 1 [1920/60000 (3%)]\tloss=2.0766\n",
      "Train Epoch: 1 [2560/60000 (4%)]\tloss=1.8679\n",
      "Train Epoch: 1 [3200/60000 (5%)]\tloss=1.4135\n",
      "Train Epoch: 1 [3840/60000 (6%)]\tloss=1.0003\n",
      "Train Epoch: 1 [4480/60000 (7%)]\tloss=0.7762\n",
      "Train Epoch: 1 [5120/60000 (9%)]\tloss=0.4598\n",
      "Train Epoch: 1 [5760/60000 (10%)]\tloss=0.4860\n",
      "Train Epoch: 1 [6400/60000 (11%)]\tloss=0.4389\n",
      "Train Epoch: 1 [7040/60000 (12%)]\tloss=0.4084\n",
      "Train Epoch: 1 [7680/60000 (13%)]\tloss=0.4602\n",
      "Train Epoch: 1 [8320/60000 (14%)]\tloss=0.4289\n",
      "Train Epoch: 1 [8960/60000 (15%)]\tloss=0.3990\n",
      "Train Epoch: 1 [9600/60000 (16%)]\tloss=0.3852\n",
      "Train Epoch: 1 [10240/60000 (17%)]\tloss=0.2984\n",
      "Train Epoch: 1 [10880/60000 (18%)]\tloss=0.5029\n",
      "Train Epoch: 1 [11520/60000 (19%)]\tloss=0.5236\n",
      "Train Epoch: 1 [12160/60000 (20%)]\tloss=0.3378\n",
      "Train Epoch: 1 [12800/60000 (21%)]\tloss=0.3674\n",
      "Train Epoch: 1 [13440/60000 (22%)]\tloss=0.4508\n",
      "Train Epoch: 1 [14080/60000 (23%)]\tloss=0.3034\n",
      "Train Epoch: 1 [14720/60000 (25%)]\tloss=0.3574\n",
      "Train Epoch: 1 [15360/60000 (26%)]\tloss=0.3313\n",
      "Train Epoch: 1 [16000/60000 (27%)]\tloss=0.4405\n",
      "Train Epoch: 1 [16640/60000 (28%)]\tloss=0.3642\n",
      "Train Epoch: 1 [17280/60000 (29%)]\tloss=0.3172\n",
      "Train Epoch: 1 [17920/60000 (30%)]\tloss=0.2016\n",
      "Train Epoch: 1 [18560/60000 (31%)]\tloss=0.4978\n",
      "Train Epoch: 1 [19200/60000 (32%)]\tloss=0.3254\n",
      "Train Epoch: 1 [19840/60000 (33%)]\tloss=0.1191\n",
      "Train Epoch: 1 [20480/60000 (34%)]\tloss=0.1905\n",
      "Train Epoch: 1 [21120/60000 (35%)]\tloss=0.1408\n",
      "Train Epoch: 1 [21760/60000 (36%)]\tloss=0.3147\n",
      "Train Epoch: 1 [22400/60000 (37%)]\tloss=0.1505\n",
      "Train Epoch: 1 [23040/60000 (38%)]\tloss=0.2898\n",
      "Train Epoch: 1 [23680/60000 (39%)]\tloss=0.4685\n",
      "Train Epoch: 1 [24320/60000 (41%)]\tloss=0.2158\n",
      "Train Epoch: 1 [24960/60000 (42%)]\tloss=0.1521\n",
      "Train Epoch: 1 [25600/60000 (43%)]\tloss=0.2248\n",
      "Train Epoch: 1 [26240/60000 (44%)]\tloss=0.2623\n",
      "Train Epoch: 1 [26880/60000 (45%)]\tloss=0.2335\n",
      "Train Epoch: 1 [27520/60000 (46%)]\tloss=0.2623\n",
      "Train Epoch: 1 [28160/60000 (47%)]\tloss=0.2126\n",
      "Train Epoch: 1 [28800/60000 (48%)]\tloss=0.1328\n",
      "Train Epoch: 1 [29440/60000 (49%)]\tloss=0.2779\n",
      "Train Epoch: 1 [30080/60000 (50%)]\tloss=0.0943\n",
      "Train Epoch: 1 [30720/60000 (51%)]\tloss=0.1285\n",
      "Train Epoch: 1 [31360/60000 (52%)]\tloss=0.2455\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tloss=0.3396\n",
      "Train Epoch: 1 [32640/60000 (54%)]\tloss=0.1523\n",
      "Train Epoch: 1 [33280/60000 (55%)]\tloss=0.0916\n",
      "Train Epoch: 1 [33920/60000 (57%)]\tloss=0.1448\n",
      "Train Epoch: 1 [34560/60000 (58%)]\tloss=0.1989\n",
      "Train Epoch: 1 [35200/60000 (59%)]\tloss=0.2183\n",
      "Train Epoch: 1 [35840/60000 (60%)]\tloss=0.0638\n",
      "Train Epoch: 1 [36480/60000 (61%)]\tloss=0.1373\n",
      "Train Epoch: 1 [37120/60000 (62%)]\tloss=0.1147\n",
      "Train Epoch: 1 [37760/60000 (63%)]\tloss=0.2355\n",
      "Train Epoch: 1 [38400/60000 (64%)]\tloss=0.0636\n",
      "Train Epoch: 1 [39040/60000 (65%)]\tloss=0.1065\n",
      "Train Epoch: 1 [39680/60000 (66%)]\tloss=0.1599\n",
      "Train Epoch: 1 [40320/60000 (67%)]\tloss=0.1090\n",
      "Train Epoch: 1 [40960/60000 (68%)]\tloss=0.1774\n",
      "Train Epoch: 1 [41600/60000 (69%)]\tloss=0.2307\n",
      "Train Epoch: 1 [42240/60000 (70%)]\tloss=0.0736\n",
      "Train Epoch: 1 [42880/60000 (71%)]\tloss=0.1553\n",
      "Train Epoch: 1 [43520/60000 (72%)]\tloss=0.2793\n",
      "Train Epoch: 1 [44160/60000 (74%)]\tloss=0.1428\n",
      "Train Epoch: 1 [44800/60000 (75%)]\tloss=0.1179\n",
      "Train Epoch: 1 [45440/60000 (76%)]\tloss=0.1205\n",
      "Train Epoch: 1 [46080/60000 (77%)]\tloss=0.0767\n",
      "Train Epoch: 1 [46720/60000 (78%)]\tloss=0.1946\n",
      "Train Epoch: 1 [47360/60000 (79%)]\tloss=0.0703\n",
      "Train Epoch: 1 [48000/60000 (80%)]\tloss=0.2094\n",
      "Train Epoch: 1 [48640/60000 (81%)]\tloss=0.1378\n",
      "Train Epoch: 1 [49280/60000 (82%)]\tloss=0.0950\n",
      "Train Epoch: 1 [49920/60000 (83%)]\tloss=0.1066\n",
      "Train Epoch: 1 [50560/60000 (84%)]\tloss=0.1182\n",
      "Train Epoch: 1 [51200/60000 (85%)]\tloss=0.1455\n",
      "Train Epoch: 1 [51840/60000 (86%)]\tloss=0.0684\n",
      "Train Epoch: 1 [52480/60000 (87%)]\tloss=0.0241\n",
      "Train Epoch: 1 [53120/60000 (88%)]\tloss=0.2626\n",
      "Train Epoch: 1 [53760/60000 (90%)]\tloss=0.0922\n",
      "Train Epoch: 1 [54400/60000 (91%)]\tloss=0.1301\n",
      "Train Epoch: 1 [55040/60000 (92%)]\tloss=0.1921\n",
      "Train Epoch: 1 [55680/60000 (93%)]\tloss=0.0346\n",
      "Train Epoch: 1 [56320/60000 (94%)]\tloss=0.0358\n",
      "Train Epoch: 1 [56960/60000 (95%)]\tloss=0.0767\n",
      "Train Epoch: 1 [57600/60000 (96%)]\tloss=0.1167\n",
      "Train Epoch: 1 [58240/60000 (97%)]\tloss=0.1932\n",
      "Train Epoch: 1 [58880/60000 (98%)]\tloss=0.2062\n",
      "Train Epoch: 1 [59520/60000 (99%)]\tloss=0.0647\n",
      "\n",
      "accuracy=0.9669\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "training_client.get_job_logs(name=name, job_kind=constants.PYTORCHJOB_KIND)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Delete the PyTorchJob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTorchJob kubeflow-user-example-com/pytorch-dist-mnist-gloo has been deleted\n"
     ]
    }
   ],
   "source": [
    "training_client.delete_job(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
